public class CountNodes {

    /**
     * Counts every node in the binary tree rooted at {@code root}.
     *
     * @param root root of the tree; may be {@code null}
     * @return number of nodes in the tree, or 0 for an empty ({@code null}) tree
     */
    public int countNodes(TreeNode root) {
        return cpuntN(root);
    }

    /**
     * Counts the nodes of the subtree rooted at {@code node}.
     * <p>
     * Uses an explicit stack instead of recursion so that very deep trees
     * (for example a degenerate, list-shaped tree) cannot overflow the call
     * stack. The misspelled method name is kept because the method is public
     * and may already have callers.
     *
     * @param node subtree root; may be {@code null}
     * @return number of nodes in the subtree, or 0 if {@code node} is {@code null}
     */
    public int cpuntN(TreeNode node) {
        if (node == null) {
            return 0;
        }
        // ArrayDeque rejects null elements, so only non-null children are pushed.
        java.util.ArrayDeque<TreeNode> pending = new java.util.ArrayDeque<TreeNode>();
        pending.push(node);
        int count = 0;
        while (!pending.isEmpty()) {
            TreeNode current = pending.pop();
            count++;
            if (current.left != null) {
                pending.push(current.left);
            }
            if (current.right != null) {
                pending.push(current.right);
            }
        }
        return count;
    }
}

